# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Parameter mappings and transformation hooks for checkpoint conversion.

This module defines the necessary components to convert model checkpoints between
MaxText and Hugging Face formats for various architectures (e.g., Gemma, Qwen).
It provides two key types of mappings for each model:

1.  **Parameter Name Mappings (`PARAM_MAPPING`)**: Dictionaries that map the string
    name of a parameter in a MaxText checkpoint to its corresponding name in a
    Hugging Face checkpoint. These mappings are generated by functions like
    `GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING`.

2.  **Hook Functions (`HOOK_FNS`)**: Dictionaries that map a MaxText parameter
    name to a specific transformation function (a "hook"). These hooks handle
    the actual value conversion, which can include operations like reshaping,
    transposing, scaling, or padding tensors to match the target format's
    requirements. These are generated by functions like
    `GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN`.

The main conversion script uses these mappings to systematically transform each
parameter from the source checkpoint and build the target checkpoint.
"""

import warnings
import numpy as np

import jax
import jax.numpy as jnp


def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Generates a parameter mapping from MaxText to Hugging Face for Gemma3.

  This function creates a dictionary that maps the parameter names from a
  MaxText Gemma3 checkpoint to their corresponding names in the Hugging Face
  `Gemma3ForCausalLM` format. It handles both the text and vision components
  of the model.

  Args:
    config (dict): The Hugging Face model configuration dictionary, which must
      contain 'text_config' and 'vision_config' sub-dictionaries.
    maxtext_config: The MaxText model configuration. Unused here; kept so all
      `*_MAXTEXT_TO_HF_PARAM_MAPPING` functions share the same signature.
    scan_layers (bool, optional): If True, generates mappings for scanned
      layers, where multiple layers are stacked into a single tensor. If False,
      generates mappings for individual, unscanned layers. Defaults to False.

  Returns:
    dict: A mapping where keys are MaxText parameter names and values are the
      corresponding Hugging Face parameter names. For scanned layers (text and
      vision), the value is a list of Hugging Face names, one per stacked
      layer.
  """
  tcfg = config["text_config"]
  vcfg = config["vision_config"]
  Ndec = tcfg["num_hidden_layers"]
  Nvis = vcfg["num_hidden_layers"]

  # pylint: disable=line-too-long
  mapping = {
      # Embedding & final norm
      "params-token_embedder-embedding": "model.language_model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.language_model.norm.weight",
      # Vision embed & pos
      "params-vision_encoder-Gemma3VisionEncoderLayer_0-embedding-kernel": "model.vision_tower.vision_model.embeddings.patch_embedding.weight",
      "params-vision_encoder-Gemma3VisionEncoderLayer_0-embedding-bias": "model.vision_tower.vision_model.embeddings.patch_embedding.bias",
      "params-vision_encoder-Gemma3VisionEncoderLayer_0-pos_embedding": "model.vision_tower.vision_model.embeddings.position_embedding.weight",
      "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoder_norm-scale": "model.vision_tower.vision_model.post_layernorm.weight",
      "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoder_norm-bias": "model.vision_tower.vision_model.post_layernorm.bias",
      # Multi-modal projector
      "params-vision_encoder-VisionEmbedder_0-mm_input_projection-w": "model.multi_modal_projector.mm_input_projection_weight",
      "params-vision_encoder-VisionEmbedder_0-mm_soft_embedding_norm-scale": "model.multi_modal_projector.mm_soft_emb_norm.weight",
  }

  vision_params = [
      ("LayerNorm_0-scale", "layer_norm1.weight"),
      ("LayerNorm_0-bias", "layer_norm1.bias"),
      ("LayerNorm_1-scale", "layer_norm2.weight"),
      ("LayerNorm_1-bias", "layer_norm2.bias"),
      ("MultiHeadDotProductAttention_0-query-kernel", "self_attn.q_proj.weight"),
      ("MultiHeadDotProductAttention_0-query-bias", "self_attn.q_proj.bias"),
      ("MultiHeadDotProductAttention_0-key-kernel", "self_attn.k_proj.weight"),
      ("MultiHeadDotProductAttention_0-key-bias", "self_attn.k_proj.bias"),
      ("MultiHeadDotProductAttention_0-value-kernel", "self_attn.v_proj.weight"),
      ("MultiHeadDotProductAttention_0-value-bias", "self_attn.v_proj.bias"),
      ("MultiHeadDotProductAttention_0-out-kernel", "self_attn.out_proj.weight"),
      ("MultiHeadDotProductAttention_0-out-bias", "self_attn.out_proj.bias"),
      ("MlpBlockViT_0-Dense_0-kernel", "mlp.fc1.weight"),
      ("MlpBlockViT_0-Dense_0-bias", "mlp.fc1.bias"),
      ("MlpBlockViT_0-Dense_1-kernel", "mlp.fc2.weight"),
      ("MlpBlockViT_0-Dense_1-bias", "mlp.fc2.bias"),
  ]

  # Vision layers mapping
  if scan_layers:
    # Scanned checkpoints stack all vision layers into a single tensor per
    # parameter, so each MaxText name maps to the full list of per-layer HF
    # names (mirroring the scanned text path below). Note: the scanned key
    # intentionally contains no layer index; previously this loop overwrote
    # the same key once per layer and only the last layer's name survived.
    for mx, hf in vision_params:
      key = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock-{mx}"
      mapping[key] = [f"model.vision_tower.vision_model.encoder.layers.{i}.{hf}" for i in range(Nvis)]
  else:
    for i in range(Nvis):
      for mx, hf in vision_params:
        key = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-{mx}"
        mapping[key] = f"model.vision_tower.vision_model.encoder.layers.{i}.{hf}"

  # Text decoder mapping
  text_params = [
      ("pre_self_attention_norm-scale", "input_layernorm.weight"),
      ("post_self_attention_norm-scale", "post_attention_layernorm.weight"),
      ("self_attention-query_norm-scale", "self_attn.q_norm.weight"),
      ("self_attention-key_norm-scale", "self_attn.k_norm.weight"),
      ("pre_ffw_norm-scale", "pre_feedforward_layernorm.weight"),
      ("post_ffw_norm-scale", "post_feedforward_layernorm.weight"),
      ("self_attention-query-kernel", "self_attn.q_proj.weight"),
      ("self_attention-key-kernel", "self_attn.k_proj.weight"),
      ("self_attention-value-kernel", "self_attn.v_proj.weight"),
      ("self_attention-out-kernel", "self_attn.o_proj.weight"),
      ("mlp-wi_0-kernel", "mlp.gate_proj.weight"),
      ("mlp-wi_1-kernel", "mlp.up_proj.weight"),
      ("mlp-wo-kernel", "mlp.down_proj.weight"),
  ]

  if scan_layers:
    for mx, hf in text_params:
      key = f"params-decoder-layers-{mx}"
      mapping[key] = [f"model.language_model.layers.{i}.{hf}" for i in range(Ndec)]
  else:
    for i in range(Ndec):
      for mx, hf in text_params:
        key = f"params-decoder-layers_{i}-{mx}"
        mapping[key] = f"model.language_model.layers.{i}.{hf}"

  return mapping


def GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Hook functions for Gemma3 parameter conversion.

  Builds a dictionary of transformation functions (hooks) used to convert
  Gemma3 parameters between MaxText and Hugging Face formats: embedding
  padding/scaling, RMSNorm offset scaling, kernel reshaping, and the
  vision-tower-specific tensor manipulations.

  Args:
    config (dict): The Hugging Face model configuration dictionary.
    scan_layers (bool, optional): Whether the model uses scanned layers.
      Defaults to False.
    saving_to_hf (bool, optional): The direction of conversion. True for
      MaxText to Hugging Face, False for the reverse. Defaults to False.

  Returns:
    dict: A dictionary mapping MaxText parameter names to their corresponding
      transformation functions.
  """

  def pad_and_scale_embedding(input_tensor, target_shape):
    """Rescales the embedding and pads/truncates its vocab dimension."""
    src_vocab, _ = input_tensor.shape
    dst_vocab, dst_hidden = target_shape

    # MaxText embedding = original_embedding * sqrt(hidden_size); the HF
    # model applies that scaling in its forward pass instead, so divide when
    # exporting to HF and multiply when importing from HF.
    normalizer = np.dtype("bfloat16").type(config["text_config"]["hidden_size"] ** 0.5)
    scaled = input_tensor / normalizer if saving_to_hf else input_tensor * normalizer
    scaled = scaled.astype(input_tensor.dtype)

    if src_vocab > dst_vocab:
      warnings.warn(f"source vocab={src_vocab} > target vocab={dst_vocab}, truncate output layer for MaxText.")
      return scaled[:dst_vocab, :]
    if src_vocab < dst_vocab:
      warnings.warn(f"source vocab={src_vocab} < target vocab={dst_vocab}, pad output layer for MaxText.")
      pad_shape = (dst_vocab - src_vocab, dst_hidden)
      # Keep the array flavor: jnp ops for JAX arrays, np ops otherwise.
      xp = jnp if isinstance(scaled, jax.Array) else np
      return xp.concatenate([scaled, xp.zeros(pad_shape, dtype=scaled.dtype)], axis=0)
    return scaled  # Vocab sizes already match.

  def scale_rmsnorm(x, target_shape):
    """Shifts RMSNorm scale: MaxText stores HF weight + 1."""
    offset = -1.0 if saving_to_hf else 1.0
    return (x + offset).reshape(target_shape)

  def reshape_kernel(x, target_shape):
    """Converts between MaxText kernel layout and HF's transposed 2-D layout."""
    if saving_to_hf:
      return x.reshape(np.flip(np.array(target_shape))).T
    return x.T.reshape(target_shape)

  def vis_bias(x, target_shape):
    """Flattens (to HF) or reshapes (to MaxText) vision attention biases."""
    return x.flatten() if saving_to_hf else x.reshape(target_shape)

  def vision_patch(x, target_shape):
    """Permutes patch-embedding conv kernel axes between the two layouts."""
    axes = (3, 2, 0, 1) if saving_to_hf else (2, 3, 1, 0)
    return x.transpose(axes)

  def pos_embed(x, target_shape):
    """Drops (to HF) or restores (to MaxText) the leading batch axis."""
    return x.squeeze(0) if saving_to_hf else x[None, :, :]

  # ---Embedding & final norm---
  hooks = {
      "params-token_embedder-embedding": pad_and_scale_embedding,
      "params-decoder-decoder_norm-scale": scale_rmsnorm,
      # MaxText pos_embedding is [1, 4096, 1152]; HF drops the leading axis.
      "params-vision_encoder-Gemma3VisionEncoderLayer_0-embedding-kernel": vision_patch,
      "params-vision_encoder-Gemma3VisionEncoderLayer_0-pos_embedding": pos_embed,
      # mm_input_projection is passed through unchanged.
      "params-vision_encoder-VisionEmbedder_0-mm_input_projection-w": lambda x, _: x,
      "params-vision_encoder-VisionEmbedder_0-mm_soft_embedding_norm-scale": scale_rmsnorm,
  }

  # Text layers: scanned models use a single index-free prefix.
  n_text = config.get("text_config", {}).get("num_hidden_layers", 0)
  text_prefixes = ["params-decoder-layers-"] if scan_layers else [f"params-decoder-layers_{i}-" for i in range(n_text)]
  text_norms = (
      "pre_self_attention_norm-scale",
      "post_self_attention_norm-scale",
      "self_attention-query_norm-scale",
      "self_attention-key_norm-scale",
      "pre_ffw_norm-scale",
      "post_ffw_norm-scale",
  )
  for pref in text_prefixes:
    # Attention Q/K/V/O kernels
    for proj in ("query", "key", "value", "out"):
      hooks[pref + f"self_attention-{proj}-kernel"] = reshape_kernel
    # Norm scales
    for name in text_norms:
      hooks[pref + name] = scale_rmsnorm
    # MLP kernels
    for name in ("mlp-wi_0-kernel", "mlp-wi_1-kernel", "mlp-wo-kernel"):
      hooks[pref + name] = reshape_kernel

  # Vision layers: same index-free prefix trick when scanned.
  n_vision = config.get("vision_config", {}).get("num_hidden_layers", 0)
  vision_bases = (
      ["params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock-"]
      if scan_layers
      else [f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-" for i in range(n_vision)]
  )
  for base in vision_bases:
    # Attention kernels & biases
    for proj in ("query", "key", "value"):
      hooks[base + f"MultiHeadDotProductAttention_0-{proj}-kernel"] = reshape_kernel
      hooks[base + f"MultiHeadDotProductAttention_0-{proj}-bias"] = vis_bias
    # out-kernel e.g. [1152, 1152] <-> [16, 72, 1152]
    hooks[base + "MultiHeadDotProductAttention_0-out-kernel"] = reshape_kernel
    hooks[base + "MultiHeadDotProductAttention_0-out-bias"] = vis_bias
    # MLP ViT kernels (biases need no transformation)
    for dense in ("Dense_0", "Dense_1"):
      hooks[base + f"MlpBlockViT_0-{dense}-kernel"] = reshape_kernel

  return hooks


def GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping between MaxText and HuggingFace Gemma2 weight paths.

  Args:
      config (dict): Model configuration dictionary containing at least
        'num_hidden_layers'.
      scan_layers (bool, optional): Whether the MaxText model uses layer
        scanning optimization. When True, decoder layers are stacked into a
        single tensor. Defaults to False.

  Returns:
      dict: A mapping where keys are MaxText parameter paths and values are
        either single strings (HF parameter path) for unscanned parameters or
        lists of strings (HF parameter paths) for stacked layers when
        `scan_layers=True`.

  Notes:
      - MaxText uses a paired layer approach where two HF decoder layers are
        treated as one MaxText decoder layer.
      - MaxText layer `i` corresponds to HF layers `2i` and `2i+1`.
      - Local components map to even-numbered HF decoder layers (0, 2, 4...).
      - Global components map to odd-numbered HF decoder layers (1, 3, 5...).
  """
  nlayers = config["num_hidden_layers"]
  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
  }

  # Each MaxText parameter name is "{base}_{variant}{tail}" where variant is
  # "global" or "local"; the HF suffix is the same for both variants.
  per_layer = [
      ("pre_self_attention_norm", "-scale", "input_layernorm.weight"),
      ("mlp", "-wo-kernel", "mlp.down_proj.weight"),
      ("mlp", "-wi_1-kernel", "mlp.up_proj.weight"),
      ("mlp", "-wi_0-kernel", "mlp.gate_proj.weight"),
      ("post_self_attention_norm", "-scale", "post_attention_layernorm.weight"),
      ("post_ffw_norm", "-scale", "post_feedforward_layernorm.weight"),
      ("pre_ffw_norm", "-scale", "pre_feedforward_layernorm.weight"),
      ("self_attention", "-key-kernel", "self_attn.k_proj.weight"),
      ("self_attention", "-out-kernel", "self_attn.o_proj.weight"),
      ("self_attention", "-query-kernel", "self_attn.q_proj.weight"),
      ("self_attention", "-value-kernel", "self_attn.v_proj.weight"),
  ]

  if scan_layers:
    # Scanned: one stacked tensor per parameter; global components collect the
    # odd HF layers, local components the even ones.
    for variant, start in (("global", 1), ("local", 0)):
      for base, tail, hf_suffix in per_layer:
        mx_name = f"params-decoder-layers-{base}_{variant}{tail}"
        mapping[mx_name] = [f"model.layers.{i}.{hf_suffix}" for i in range(start, nlayers, 2)]
  else:
    # Unscanned: MaxText layer i holds HF local layer 2i and global layer 2i+1.
    for mt_idx in range(nlayers // 2):
      for variant, hf_idx in (("global", 2 * mt_idx + 1), ("local", 2 * mt_idx)):
        for base, tail, hf_suffix in per_layer:
          mx_name = f"params-decoder-layers_{mt_idx}-{base}_{variant}{tail}"
          mapping[mx_name] = f"model.layers.{hf_idx}.{hf_suffix}"

  return mapping


def GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for Gemma2 conversion.

  Generates a mapping of transformation functions covering the conversions
  between MaxText and HuggingFace parameter formats for Gemma2: embedding
  padding/scaling, kernel reshaping, RMSNorm offsetting, and query scaling.

  Args:
      config (dict): Model configuration dictionary that must contain:
          - num_hidden_layers (int): Number of layers in the model.
          - head_dim (int): Dimension of attention heads.
          - hidden_size (int): Model's hidden dimension size.
      scan_layers (bool, optional): Controls the output format for layer
        parameters. True for batched, False for individual. Defaults to False.
      saving_to_hf (bool, optional): Determines the direction of transformation.
        True for MaxText to HuggingFace, False for the reverse. Defaults to
        False.

  Returns:
      dict: A mapping from MaxText parameter names to transformation functions.
        The value can be a single function or a list of functions to be
        applied sequentially.
  """
  nlayers = config["num_hidden_layers"]

  def pad_hf_embedding_layer(input_tensor, target_shape):
    """Pads/unpads and scales the embedding layer.

    Note:
        HF embedding weights shape =  [256000, d_model]
        MaxText embedding weights shape = [256128, d_model]
        MaxText pads Gemma2 embedding to 256128 for better performance.
    """
    # TODO(wenxindongwork), Perhaps, this dtype should be the activation dtype
    normalizer = np.dtype("float32").type(config["hidden_size"] ** 0.5)
    if saving_to_hf:
      # Slice away MaxText's vocab padding, then undo the sqrt(d_model) scale.
      result = input_tensor[: target_shape[0], : target_shape[1]] / normalizer
    else:
      # Zero-pad to MaxText's vocab size, then fold the scale in.
      result = np.zeros(target_shape, dtype=input_tensor.dtype)
      result[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor
      result = result * normalizer
    return result.astype(input_tensor.dtype)

  def reshape_kernel(input_tensor, target_shape):
    """Converts between MaxText kernel layout and HF's transposed 2-D layout."""
    if saving_to_hf:
      return input_tensor.reshape(np.flip(np.array(target_shape))).T
    return input_tensor.T.reshape(target_shape)

  def scale_rmsnorm_layer(input_tensor, target_shape):
    """Shifts RMSNorm scale: MaxText stores HF weight + 1."""
    offset = -1.0 if saving_to_hf else 1.0
    return (input_tensor + offset).reshape(target_shape)

  def scale_query_layer(input_tensor, target_shape):
    """Applies (or removes) the sqrt(head_dim) query pre-scaling."""
    root = np.sqrt(config["head_dim"])
    depth_scale = np.dtype("float32").type(root if saving_to_hf else 1 / root)
    return (input_tensor * depth_scale).astype(input_tensor.dtype)

  hooks = {
      "params-token_embedder-embedding": pad_hf_embedding_layer,
      "params-decoder-decoder_norm-scale": scale_rmsnorm_layer,
  }

  # Per-layer parameter name suffixes, grouped by which hook(s) they need.
  query_suffixes = (
      "self_attention_global-query-kernel",
      "self_attention_local-query-kernel",
  )
  kernel_suffixes = (
      "self_attention_global-key-kernel",
      "self_attention_local-key-kernel",
      "self_attention_global-value-kernel",
      "self_attention_local-value-kernel",
      "mlp_global-wo-kernel",
      "mlp_global-wi_1-kernel",
      "mlp_global-wi_0-kernel",
      "self_attention_global-out-kernel",
      "mlp_local-wo-kernel",
      "mlp_local-wi_1-kernel",
      "mlp_local-wi_0-kernel",
      "self_attention_local-out-kernel",
  )
  norm_suffixes = (
      "pre_self_attention_norm_global-scale",
      "post_self_attention_norm_global-scale",
      "post_ffw_norm_global-scale",
      "pre_ffw_norm_global-scale",
      "pre_self_attention_norm_local-scale",
      "post_self_attention_norm_local-scale",
      "post_ffw_norm_local-scale",
      "pre_ffw_norm_local-scale",
  )

  # Scanned models use one index-free prefix; unscanned models use one prefix
  # per MaxText layer (each MaxText layer pairs two HF layers).
  prefixes = ["params-decoder-layers-"] if scan_layers else [f"params-decoder-layers_{i}-" for i in range(nlayers // 2)]
  for prefix in prefixes:
    for suffix in query_suffixes:
      # Queries need both the reshape and the depth scaling, in order.
      hooks[prefix + suffix] = [reshape_kernel, scale_query_layer]
    for suffix in kernel_suffixes:
      hooks[prefix + suffix] = reshape_kernel
    for suffix in norm_suffixes:
      hooks[prefix + suffix] = scale_rmsnorm_layer

  return hooks


def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping from MaxText to HuggingFace Qwen3 weight paths.

  This function generates a dictionary that maps parameter names from a MaxText
  Qwen3 checkpoint to their corresponding names in the Hugging Face format.
  It handles both dense and Mixture-of-Experts (MoE) model variants.

  Args:
    config (dict): Model configuration dictionary, including
      'num_hidden_layers' and optionally 'num_experts'.
    maxtext_config: Unused here; accepted for signature parity with the other
      mapping functions in this module.
    scan_layers (bool, optional): Whether the MaxText model uses scanned
      layers. Defaults to False.

  Returns:
    dict: A mapping from MaxText parameter names to Hugging Face parameter
      names. For scanned or MoE layers, the value may be a list or a nested
      list of names.
  """
  n_layers = config["num_hidden_layers"]
  num_experts = config.get("num_experts", 0)

  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }

  if scan_layers:
    # This block handles scanned layers for both dense and MoE models.
    mapping.update(
        {
            "params-decoder-layers-pre_self_attention_layer_norm-scale": [
                f"model.layers.{i}.input_layernorm.weight" for i in range(n_layers)
            ],
            "params-decoder-layers-self_attention-query-kernel": [
                f"model.layers.{i}.self_attn.q_proj.weight" for i in range(n_layers)
            ],
            "params-decoder-layers-self_attention-key-kernel": [
                f"model.layers.{i}.self_attn.k_proj.weight" for i in range(n_layers)
            ],
            "params-decoder-layers-self_attention-value-kernel": [
                f"model.layers.{i}.self_attn.v_proj.weight" for i in range(n_layers)
            ],
            "params-decoder-layers-self_attention-out-kernel": [
                f"model.layers.{i}.self_attn.o_proj.weight" for i in range(n_layers)
            ],
            "params-decoder-layers-self_attention-query_norm-scale": [
                f"model.layers.{i}.self_attn.q_norm.weight" for i in range(n_layers)
            ],
            "params-decoder-layers-self_attention-key_norm-scale": [
                f"model.layers.{i}.self_attn.k_norm.weight" for i in range(n_layers)
            ],
            "params-decoder-layers-post_self_attention_layer_norm-scale": [
                f"model.layers.{i}.post_attention_layernorm.weight" for i in range(n_layers)
            ],
        }
    )
    if num_experts > 1:
      # For scanned MoE, we create a nested list: [[e0_l0, e0_l1..], [e1_l0, e1_l1..]..]
      # This follows the (experts, layers, ...) tensor layout.
      mapping.update(
          {
              "params-decoder-layers-moe_block-gate-kernel": [
                  f"model.layers.{i}.mlp.gate.weight" for i in range(n_layers)
              ],
              "params-decoder-layers-moe_block-wi_0": [
                  [f"model.layers.{l}.mlp.experts.{e}.gate_proj.weight" for l in range(n_layers)]
                  for e in range(num_experts)
              ],
              "params-decoder-layers-moe_block-wi_1": [
                  [f"model.layers.{l}.mlp.experts.{e}.up_proj.weight" for l in range(n_layers)]
                  for e in range(num_experts)
              ],
              "params-decoder-layers-moe_block-wo": [
                  [f"model.layers.{l}.mlp.experts.{e}.down_proj.weight" for l in range(n_layers)]
                  for e in range(num_experts)
              ],
          }
      )
    else:  # Dense MLP
      mapping.update(
          {
              "params-decoder-layers-mlp-wi_0-kernel": [
                  f"model.layers.{i}.mlp.gate_proj.weight" for i in range(n_layers)
              ],
              "params-decoder-layers-mlp-wi_1-kernel": [f"model.layers.{i}.mlp.up_proj.weight" for i in range(n_layers)],
              "params-decoder-layers-mlp-wo-kernel": [f"model.layers.{i}.mlp.down_proj.weight" for i in range(n_layers)],
          }
      )
  else:  # unscanned layers
    for i in range(n_layers):
      # Common Attention and Norms
      # pylint: disable=line-too-long
      # NOTE: the post_self_attention_layer_norm entry was previously duplicated
      # in this literal; it is listed exactly once now.
      mapping.update(
          {
              f"params-decoder-layers_{i}-pre_self_attention_layer_norm-scale": f"model.layers.{i}.input_layernorm.weight",
              f"params-decoder-layers_{i}-self_attention-query-kernel": f"model.layers.{i}.self_attn.q_proj.weight",
              f"params-decoder-layers_{i}-self_attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
              f"params-decoder-layers_{i}-self_attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
              f"params-decoder-layers_{i}-self_attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
              f"params-decoder-layers_{i}-self_attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
              f"params-decoder-layers_{i}-self_attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
              f"params-decoder-layers_{i}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight",
          }
      )
      if num_experts > 1:
        # For each unscanned MoE layer, map the MaxText parameter to a 1D list of all expert weights for that layer.
        mapping.update(
            {
                f"params-decoder-layers_{i}-moe_block-gate-kernel": f"model.layers.{i}.mlp.gate.weight",
                f"params-decoder-layers_{i}-moe_block-wi_0": [
                    f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)
                ],
                f"params-decoder-layers_{i}-moe_block-wi_1": [
                    f"model.layers.{i}.mlp.experts.{j}.up_proj.weight" for j in range(num_experts)
                ],
                f"params-decoder-layers_{i}-moe_block-wo": [
                    f"model.layers.{i}.mlp.experts.{j}.down_proj.weight" for j in range(num_experts)
                ],
            }
        )
      else:  # Dense MLP
        mapping.update(
            {
                f"params-decoder-layers_{i}-mlp-wi_0-kernel": f"model.layers.{i}.mlp.gate_proj.weight",
                f"params-decoder-layers_{i}-mlp-wi_1-kernel": f"model.layers.{i}.mlp.up_proj.weight",
                f"params-decoder-layers_{i}-mlp-wo-kernel": f"model.layers.{i}.mlp.down_proj.weight",
            }
        )
  return mapping


def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Builds the value-transformation hooks for Qwen3 checkpoint conversion.

  Covers embedding-vocabulary padding/truncation and kernel reshaping for both
  dense and Mixture-of-Experts variants, scanned or unscanned.

  Args:
    config (dict): Model configuration with 'num_hidden_layers' and optionally
      'num_experts'.
    maxtext_config: Unused; kept for signature parity.
    scan_layers (bool, optional): Whether the model uses scanned layers.
      Defaults to False.
    saving_to_hf (bool, optional): True for MaxText -> Hugging Face, False for
      the reverse direction. Defaults to False.

  Returns:
    dict: MaxText parameter name -> transformation function.
  """
  layer_count = config["num_hidden_layers"]
  expert_count = config.get("num_experts", 0)

  def pad_embedding_layer(input_tensor, target_shape):
    """Pads or truncates embedding layer to match target vocab size."""
    src_vocab = input_tensor.shape[0]
    dst_vocab = target_shape[0]
    if src_vocab == dst_vocab:
      return input_tensor
    if saving_to_hf:
      # MaxText -> HF: drop the extra padding rows.
      return input_tensor[:dst_vocab, :]
    # HF -> MaxText: zero-pad up to the MaxText vocab size.
    padded = np.zeros(target_shape, dtype=input_tensor.dtype)
    padded[:src_vocab, :] = input_tensor
    return padded

  def reshape_kernel(input_tensor, target_shape):
    """Reshapes and transposes kernel weights between MaxText and HF."""
    if saving_to_hf:
      return input_tensor.reshape(np.flip(np.array(target_shape))).T
    return input_tensor.T.reshape(target_shape)

  hooks = {
      "params-token_embedder-embedding": pad_embedding_layer,
      "params-decoder-logits_dense-kernel": reshape_kernel,
  }

  dense_suffixes = (
      "self_attention-query-kernel",
      "self_attention-key-kernel",
      "self_attention-value-kernel",
      "self_attention-out-kernel",
      "mlp-wi_0-kernel",
      "mlp-wi_1-kernel",
      "mlp-wo-kernel",
  )
  moe_suffixes = (
      "moe_block-gate-kernel",
      "moe_block-wi_0-kernel",
      "moe_block-wi_1-kernel",
      "moe_block-wo-kernel",
      "moe_block-wi_0",
      "moe_block-wi_1",
      "moe_block-wo",
  )

  # MoE weights only need hooks when the model actually has routed experts.
  suffixes = dense_suffixes + (moe_suffixes if expert_count > 1 else ())
  if scan_layers:
    for suffix in suffixes:
      hooks[f"params-decoder-layers-{suffix}"] = reshape_kernel
  else:
    for layer_idx in range(layer_count):
      for suffix in suffixes:
        hooks[f"params-decoder-layers_{layer_idx}-{suffix}"] = reshape_kernel
  return hooks


def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Maps MaxText Deepseek weight names to Hugging Face weight names.

  Only scanned MaxText checkpoints are supported: stacked MaxText names map to
  lists of per-layer HF names, and routed-expert weights map to nested lists
  laid out as (experts, layers).
  """
  # TODO(shuningjin): add unscan support, b/457820735
  if not scan_layers:
    raise NotImplementedError("This conversion only supports scanned MaxText models.")

  # HF layer layout (MTP layers excluded): the first `first_k_dense_replace`
  # decoder layers are dense, the remainder are MoE.
  total_layers = config["num_hidden_layers"]
  dense_count = config["first_k_dense_replace"]
  routed_experts = config.get("n_routed_experts", 0)
  dense_range = range(dense_count)
  moe_range = range(dense_count, total_layers)

  # Weights not tied to any decoder layer.
  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }

  # Attention weights appear in both dense and MoE layers.
  attention_pairs = {
      "pre_self_attention_layer_norm-scale": "input_layernorm.weight",
      "post_self_attention_layer_norm-scale": "post_attention_layernorm.weight",
      "self_attention-wkv_a-kernel": "self_attn.kv_a_proj_with_mqa.weight",
      "self_attention-kv_norm-scale": "self_attn.kv_a_layernorm.weight",
      "self_attention-wkv_b-kernel": "self_attn.kv_b_proj.weight",
      "self_attention-out-kernel": "self_attn.o_proj.weight",
      # v3 (low-rank query projection)
      "self_attention-wq_a-kernel": "self_attn.q_a_proj.weight",
      "self_attention-q_norm-scale": "self_attn.q_a_layernorm.weight",
      "self_attention-wq_b-kernel": "self_attn.q_b_proj.weight",
      # v2 (full-rank query projection)
      "self_attention-query-kernel": "self_attn.q_proj.weight",
  }

  # Dense layers: attention plus a plain MLP.
  dense_pairs = attention_pairs | {
      "mlp-wi_0-kernel": "mlp.gate_proj.weight",
      "mlp-wi_1-kernel": "mlp.up_proj.weight",
      "mlp-wo-kernel": "mlp.down_proj.weight",
  }
  for mt_key, hf_key in dense_pairs.items():
    mapping[f"params-decoder-dense_layers-{mt_key}"] = [f"model.layers.{i}.{hf_key}" for i in dense_range]

  # MoE layers: attention plus shared experts and the router gate.
  moe_pairs = attention_pairs | {
      "DeepSeekMoeBlock_0-shared_experts-wi_0-kernel": "mlp.shared_experts.gate_proj.weight",
      "DeepSeekMoeBlock_0-shared_experts-wi_1-kernel": "mlp.shared_experts.up_proj.weight",
      "DeepSeekMoeBlock_0-shared_experts-wo-kernel": "mlp.shared_experts.down_proj.weight",
      "DeepSeekMoeBlock_0-MoeBlock_0-gate-kernel": "mlp.gate.weight",
      # v3
      "DeepSeekMoeBlock_0-MoeBlock_0-gate-bias": "mlp.gate.e_score_correction_bias",
  }
  for mt_key, hf_key in moe_pairs.items():
    mapping[f"params-decoder-moe_layers-{mt_key}"] = [f"model.layers.{i}.{hf_key}" for i in moe_range]

  # Routed experts: nested lists [[e0_l0, e0_l1, ...], [e1_l0, ...], ...].
  expert_pairs = {
      "DeepSeekMoeBlock_0-MoeBlock_0-wi_0": "gate_proj.weight",
      "DeepSeekMoeBlock_0-MoeBlock_0-wi_1": "up_proj.weight",
      "DeepSeekMoeBlock_0-MoeBlock_0-wo": "down_proj.weight",
  }
  for mt_key, hf_key in expert_pairs.items():
    mapping[f"params-decoder-moe_layers-{mt_key}"] = [
        [f"model.layers.{l}.mlp.experts.{e}.{hf_key}" for l in moe_range] for e in range(routed_experts)
    ]
  return mapping


def DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Builds the value-transformation hooks for Deepseek checkpoint conversion."""
  # TODO(shuningjin): support hf->orbax(scan), b/457820372
  if not saving_to_hf:
    raise NotImplementedError("This conversion only supports saving_to_hf")
  # TODO(shuningjin): add unscan support, b/457820735
  if not scan_layers:
    raise NotImplementedError("This conversion only supports scanned MaxText models.")

  def reshape_kernel(input_tensor, target_shape):
    """Reshapes and transposes kernel weights between MaxText and HF."""
    if saving_to_hf:
      return input_tensor.reshape(np.flip(np.array(target_shape))).T
    return input_tensor.T.reshape(target_shape)

  # Attention kernels needing a reshape, present in both dense and MoE layers.
  attention_suffixes = (
      "self_attention-query-kernel",
      "self_attention-wq_a-kernel",
      "self_attention-wq_b-kernel",
      "self_attention-wkv_a-kernel",
      "self_attention-wkv_b-kernel",
      "self_attention-out-kernel",
  )
  # MLP kernels that exist only in the dense layers.
  dense_suffixes = (
      "mlp-wi_0-kernel",
      "mlp-wi_1-kernel",
      "mlp-wo-kernel",
  )
  # Shared-expert, gate, and routed-expert kernels in the MoE layers.
  moe_suffixes = (
      "DeepSeekMoeBlock_0-shared_experts-wi_0-kernel",
      "DeepSeekMoeBlock_0-shared_experts-wi_1-kernel",
      "DeepSeekMoeBlock_0-shared_experts-wo-kernel",
      "DeepSeekMoeBlock_0-MoeBlock_0-gate-kernel",
      "DeepSeekMoeBlock_0-MoeBlock_0-wi_0",
      "DeepSeekMoeBlock_0-MoeBlock_0-wi_1",
      "DeepSeekMoeBlock_0-MoeBlock_0-wo",
  )

  hooks = {"params-decoder-logits_dense-kernel": reshape_kernel}
  for suffix in attention_suffixes + dense_suffixes:
    hooks[f"params-decoder-dense_layers-{suffix}"] = reshape_kernel
  for suffix in attention_suffixes + moe_suffixes:
    hooks[f"params-decoder-moe_layers-{suffix}"] = reshape_kernel
  return hooks


def DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN():
  """Returns NNX-to-vLLM transformation hooks for Deepseek.

  No value transformations are currently required, so the hook table is empty.
  """
  return {}


def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Maps MaxText gpt-oss weight names to Hugging Face weight names.

  The scanned checkpoint folds every `inhomogeneous_layer_cycle_interval`-th HF
  layer into one scan block, so each MaxText name maps to a list of HF names.

  Fused gate/up expert projections use tuple keys (N-to-1 from MaxText to HF):
  - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj
  - (GptOssMlp-wi_0_bias, GptOssMlp-wi_1_bias): mlp.experts.gate_up_proj_bias
  """
  # TODO(shuningjin): add unscan support, b/459541579
  if not scan_layers:
    raise NotImplementedError("Current gpt-oss mapping only supports scan_layers=True")

  num_layers = config["num_hidden_layers"]  # hf config
  interval = maxtext_config.inhomogeneous_layer_cycle_interval

  def expand(indices, hf_suffix):
    # One HF name per original layer folded into the scan block.
    return [f"model.layers.{i}.{hf_suffix}" for i in indices]

  # Non-layer parameters map to the standard HF keys.
  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }

  for block_idx in range(interval):
    # All original HF layer indices that collapse into this block.
    hf_indices = list(range(block_idx, num_layers, interval))
    base = f"params-decoder-layers-layers_{block_idx}"

    # Layer norms.
    mapping[f"{base}-pre_self_attention_layer_norm-scale"] = expand(hf_indices, "input_layernorm.weight")
    mapping[f"{base}-post_self_attention_layer_norm-scale"] = expand(hf_indices, "post_attention_layernorm.weight")

    # GptOssAttention projections, biases, and sinks.
    mapping[f"{base}-GptOssAttention-query-kernel"] = expand(hf_indices, "self_attn.q_proj.weight")
    mapping[f"{base}-GptOssAttention-query-bias"] = expand(hf_indices, "self_attn.q_proj.bias")
    mapping[f"{base}-GptOssAttention-key-kernel"] = expand(hf_indices, "self_attn.k_proj.weight")
    mapping[f"{base}-GptOssAttention-key-bias"] = expand(hf_indices, "self_attn.k_proj.bias")
    mapping[f"{base}-GptOssAttention-value-kernel"] = expand(hf_indices, "self_attn.v_proj.weight")
    mapping[f"{base}-GptOssAttention-value-bias"] = expand(hf_indices, "self_attn.v_proj.bias")
    mapping[f"{base}-GptOssAttention-out-kernel"] = expand(hf_indices, "self_attn.o_proj.weight")
    mapping[f"{base}-GptOssAttention-out-bias"] = expand(hf_indices, "self_attn.o_proj.bias")
    mapping[f"{base}-GptOssAttention-sinks"] = expand(hf_indices, "self_attn.sinks")

    # GptOssMlp router.
    mapping[f"{base}-GptOssMlp-gate-kernel"] = expand(hf_indices, "mlp.router.weight")
    mapping[f"{base}-GptOssMlp-gate-bias"] = expand(hf_indices, "mlp.router.bias")

    # Experts: down projection.
    mapping[f"{base}-GptOssMlp-wo"] = expand(hf_indices, "mlp.experts.down_proj")
    mapping[f"{base}-GptOssMlp-wo_bias"] = expand(hf_indices, "mlp.experts.down_proj_bias")

    # Experts: fused gate/up projection — N-to-1 mapping keyed by a tuple.
    mapping[(f"{base}-GptOssMlp-wi_0", f"{base}-GptOssMlp-wi_1")] = expand(hf_indices, "mlp.experts.gate_up_proj")
    mapping[(f"{base}-GptOssMlp-wi_0_bias", f"{base}-GptOssMlp-wi_1_bias")] = expand(
        hf_indices, "mlp.experts.gate_up_proj_bias"
    )

  return mapping


def GPT_OSS_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Transformation hooks for gpt-oss parameters.

  Handles the inhomogeneous scan block structure (inhomogeneous_layer_cycle_interval)

  Handles N-to-1 mapping from maxtext to huggingface
  - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj
  - (GptOssMlp-wi_0_bias, GptOssMlp-wi_1_bias): mlp.experts.gate_up_proj_bias

  Args:
    config (dict): HF model configuration. Unused here; kept for signature
      parity with the other hook-fn builders.
    maxtext_config: MaxText config providing
      `inhomogeneous_layer_cycle_interval`.
    scan_layers (bool): Must be True; only scanned checkpoints are supported.
    saving_to_hf (bool): Must be True; only MaxText -> HF is supported.

  Returns:
    dict: MaxText parameter name (or tuple of names for fused parameters) ->
      transformation function.
  """
  # TODO(shuningjin): support hf->orbax(scan), b/459541579
  if not saving_to_hf:
    raise NotImplementedError("Currently gpt-oss only supports saving_to_hf=True.")
  # TODO(shuningjin): add unscan support, b/459541579
  if not scan_layers:
    raise NotImplementedError("Currently gpt-oss only supports scan_layers=True.")

  def transpose(input_tensor, target_shape=None):
    """Transposes a 2D kernel.

    Both conversion directions are a plain transpose, so the previous
    `if saving_to_hf` branch (whose arms were identical) is removed.
    """
    return input_tensor.T

  def reshape_kernel(input_tensor, target_shape):
    """Reshapes and transposes kernel weights between MaxText and HF."""
    if saving_to_hf:
      flipped_target_shape = np.flip(np.array(target_shape))
      return input_tensor.reshape(flipped_target_shape).T
    else:
      return input_tensor.T.reshape(target_shape)

  def reshape_bias(input_tensor, target_shape=None):
    """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden).

    Both directions are a plain reshape to the caller-provided target shape,
    so no direction branch is needed (the previous branches were identical).
    """
    return input_tensor.reshape(target_shape)

  def interleave(input_tensor, target_shape=None):
    """N-to-1 mapping: maxtext (wi_0, wi_1) <-> hf (wi_0_1).

    If saving_to_hf, input_tensor is a sequence of two tensors which are
    interleaved along the last axis (even slots from wi_0, odd from wi_1).
    """
    if saving_to_hf:
      # (wi_0, wi_1) -> wi_0_1
      wi_0, wi_1 = input_tensor
      wi_0_1 = np.empty(target_shape, dtype=wi_0.dtype)
      wi_0_1[..., ::2] = wi_0
      wi_0_1[..., 1::2] = wi_1
      return wi_0_1
    else:
      # wi_0_1 -> (wi_0, wi_1)
      # TODO(shuningjin): support hf->orbax(scan), b/459541579
      raise NotImplementedError

  hooks = {
      "params-decoder-logits_dense-kernel": transpose,
  }

  # Scan over blocks
  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
  for block_idx in range(layer_cycle_interval):
    prefix = f"params-decoder-layers-layers_{block_idx}"
    # Attention Kernels & Biases
    for key in ["query", "key", "value"]:
      hooks[f"{prefix}-GptOssAttention-{key}-kernel"] = reshape_kernel
      hooks[f"{prefix}-GptOssAttention-{key}-bias"] = reshape_bias

    hooks[f"{prefix}-GptOssAttention-out-kernel"] = reshape_kernel

    # MLP Kernels & Biases
    hooks[f"{prefix}-GptOssMlp-gate-kernel"] = transpose
    # Experts (Gate/Up Fused Projection), N-to-1 mapping
    hooks[(f"{prefix}-GptOssMlp-wi_0", f"{prefix}-GptOssMlp-wi_1")] = interleave
    hooks[(f"{prefix}-GptOssMlp-wi_0_bias", f"{prefix}-GptOssMlp-wi_1_bias")] = interleave

  return hooks


def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping from MaxText to HuggingFace Qwen3-Omni weight paths.

  This function combines mappings from different modalities (text, vision, audio, etc.)
  into a unified parameter mapping for the multi-modal Qwen3-Omni model.

  Args:
    config (dict): Model configuration dictionary containing modality-specific configs.
    maxtext_config: Forwarded to the per-modality mapping functions.
    scan_layers (bool, optional): Whether the model uses scanned layers. Defaults to False.

  Returns:
    dict: Combined mapping from all modalities.
  """

  def _prefixed(value, prefix):
    """Recursively prefixes every HF weight name in a possibly-nested value.

    Mapping values may be a plain string, a list of names (scanned layers), or
    a nested list of names (scanned MoE experts). The previous implementation
    only handled one level of list and would stringify inner lists as
    "thinker.['model...']" for scanned MoE models.
    """
    if isinstance(value, list):
      return [_prefixed(v, prefix) for v in value]
    return f"{prefix}{value}"

  # Collect all modality mappings
  mapping = {}

  # Text mapping with "thinker." prefix, reusing QWEN3-MOE mapping function
  num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
  n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
  text_mapping = QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(
      config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
      maxtext_config=maxtext_config,
      scan_layers=scan_layers,
  )

  # Text weights live under the "thinker." namespace in the HF checkpoint.
  for key, value in text_mapping.items():
    text_mapping[key] = _prefixed(value, "thinker.")
  mapping.update(text_mapping)

  # TODO(hengtaoguo): Add vision, audio, and other modality mappings here similarly
  # mapping.update(vision_mapping), mapping.update(audio_mapping), etc.

  return mapping


def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Builds value-transformation hooks for the multi-modal Qwen3-Omni model.

  Combines per-modality hook tables (currently text only) into a single
  dictionary of MaxText parameter name -> transformation function.

  Args:
    config (dict): Model configuration containing modality-specific configs
      under 'thinker_config'.
    maxtext_config: Forwarded to the per-modality hook builders.
    scan_layers (bool, optional): Whether the model uses scanned layers.
      Defaults to False.
    saving_to_hf (bool, optional): True for MaxText -> Hugging Face, False for
      the reverse direction. Defaults to False.

  Returns:
    dict: Combined hook table for all modalities.
  """
  text_cfg = config["thinker_config"]["text_config"]

  hooks = {}
  # The text tower reuses the Qwen3-MoE hook builder unchanged.
  hooks.update(
      QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
          config={
              "num_hidden_layers": text_cfg["num_hidden_layers"],
              "num_experts": text_cfg.get("num_experts", 0),
          },
          maxtext_config=maxtext_config,
          scan_layers=scan_layers,
          saving_to_hf=saving_to_hf,
      )
  )

  # TODO(hengtaoguo): Add vision, audio, and other modality mappings here similarly
  # mapping.update(vision_hooks), mapping.update(audio_hooks), etc.

  return hooks


def QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN(target_shape=None):
  """Returns NNX-to-vLLM transformation hooks for Qwen3.

  Qwen3 weights currently map to vLLM without any value transformations, so
  the hook table is empty.

  Returns:
    dict: NNX parameter name -> transformation function (currently empty).
  """
  return {}


def LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Maps MaxText parameter names to HuggingFace LLaMA3.1 parameter names.

  Args:
      config (dict): Model configuration dictionary containing:
          - num_hidden_layers (int): The number of decoder layers.
      maxtext_config: Unused; kept for signature parity.
      scan_layers (bool, optional): If True, MaxText layers are 'stacked'
          into a single param. Defaults to False.

  Returns:
      dict: A mapping from MaxText parameter names to HF parameter names (str)
            or lists of names (if scan_layers=True).
  """
  num_layers = config["num_hidden_layers"]

  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
  }

  # MaxText suffix -> HF suffix for every per-layer parameter.
  per_layer_pairs = {
      "self_attention-query-kernel": "self_attn.q_proj.weight",
      "self_attention-key-kernel": "self_attn.k_proj.weight",
      "self_attention-value-kernel": "self_attn.v_proj.weight",
      "self_attention-out-kernel": "self_attn.o_proj.weight",
      "mlp-wi_0-kernel": "mlp.gate_proj.weight",
      "mlp-wi_1-kernel": "mlp.up_proj.weight",
      "mlp-wo-kernel": "mlp.down_proj.weight",
      "pre_self_attention_layer_norm-scale": "input_layernorm.weight",
      "post_self_attention_layer_norm-scale": "post_attention_layernorm.weight",
  }

  if scan_layers:
    # One stacked MaxText param maps to the list of per-layer HF names.
    for mt_suffix, hf_suffix in per_layer_pairs.items():
      mapping[f"params-decoder-layers-{mt_suffix}"] = [
          f"model.layers.{idx}.{hf_suffix}" for idx in range(num_layers)
      ]
  else:
    # Each layer has its own MaxText param name.
    for idx in range(num_layers):
      for mt_suffix, hf_suffix in per_layer_pairs.items():
        mapping[f"params-decoder-layers_{idx}-{mt_suffix}"] = f"model.layers.{idx}.{hf_suffix}"

  return mapping


def LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Builds the value-transformation hooks for LLaMA3.1 checkpoint conversion.

  Returns a dict mapping MaxText parameter names to a transformation function
  or a list of functions (applied in order) that reshape kernels, convert RoPE
  layouts, and rescale the query projection between formats.
  """
  num_layers = config["num_hidden_layers"]

  def scale_query_layer(input_tensor, target_shape):
    """Scales the query kernel by sqrt(head_dim) (MaxText->HF) or its inverse."""
    if saving_to_hf:
      depth_scale = np.dtype("float32").type(np.sqrt(config["head_dim"]))
    else:
      depth_scale = np.dtype("float32").type(1 / np.sqrt(config["head_dim"]))
    original_dtype = input_tensor.dtype
    # Scale in float32, then restore the checkpoint dtype.
    return (input_tensor.astype(np.float32) * depth_scale).astype(original_dtype)

  def adjust_rope(input_tensor, target_shape):
    """Converts RoPE layout: HF concatenated halves <-> MaxText interleaved."""
    if saving_to_hf:
      # Interleaved -> concatenated: evens first, then odds.
      evens = input_tensor[..., ::2]
      odds = input_tensor[..., 1::2]
      return jax.numpy.concatenate((evens, odds), axis=input_tensor.ndim - 1)
    # Concatenated -> interleaved: zip the two halves back together.
    half_dim = input_tensor.shape[-1] // 2
    first_half = input_tensor[..., :half_dim]
    second_half = input_tensor[..., half_dim:]
    return jax.numpy.stack([first_half, second_half], axis=-1).reshape(input_tensor.shape)

  def reshape_kernel(input_tensor, target_shape):
    """Reshapes/transposes a kernel between MaxText and HF layouts."""
    if saving_to_hf:
      return input_tensor.reshape(np.flip(np.array(target_shape))).transpose()
    return input_tensor.transpose().reshape(target_shape)

  # Hook order depends on direction: reshape first toward HF, last toward MaxText.
  if saving_to_hf:
    query_hooks = [reshape_kernel, adjust_rope, scale_query_layer]
    key_hooks = [reshape_kernel, adjust_rope]
  else:
    query_hooks = [scale_query_layer, adjust_rope, reshape_kernel]
    key_hooks = [adjust_rope, reshape_kernel]

  hook_fns = {"params-decoder-logits_dense-kernel": reshape_kernel}

  per_layer_hooks = {
      "self_attention-query-kernel": query_hooks,
      "self_attention-key-kernel": key_hooks,
      "self_attention-value-kernel": reshape_kernel,
      "self_attention-out-kernel": reshape_kernel,
      "mlp-wi_0-kernel": reshape_kernel,
      "mlp-wi_1-kernel": reshape_kernel,
      "mlp-wo-kernel": reshape_kernel,
  }
  if scan_layers:
    for suffix, hook in per_layer_hooks.items():
      hook_fns[f"params-decoder-layers-{suffix}"] = hook
  else:
    for idx in range(num_layers):
      for suffix, hook in per_layer_hooks.items():
        hook_fns[f"params-decoder-layers_{idx}-{suffix}"] = hook
  return hook_fns


def LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN():
  """Builds the weight-transformation hooks for Llama 3.1 NNX -> vLLM.

  The returned hooks cover the two attention kernels whose conversion is
  more than a plain re-mapping: RoPE dimension reordering for both query
  and key, plus head-dimension scaling for the query.

  Returns:
    A dictionary where keys are MaxText parameter names and values are
    the corresponding transformation functions.
  """

  def _regroup_rope_dims(weights):
    """Reorders Rotary Position Embedding (RoPE) weights.

    MaxText and HuggingFace's vLLM implementations may order the RoPE
    dimensions differently, so the even- and odd-indexed slices of the
    last axis are regrouped into two contiguous halves.

    Args:
      weights: The input weight array.

    Returns:
      The reordered weight array.
    """
    first_half = weights[..., 0::2]
    second_half = weights[..., 1::2]
    return jnp.concatenate((first_half, second_half), axis=-1)

  def _scale_and_regroup_query(weights):
    """Transforms the query kernel.

    The kernel is scaled by the square root of the head dimension and then
    run through the RoPE reordering.

    Args:
      weights: The query kernel weight array.

    Returns:
      The transformed query kernel array.
    """
    head_dim = weights.shape[-1]
    # Scale factor is materialized as float32 to keep the multiply's dtype
    # behavior identical regardless of the input array's precision.
    depth_scale = np.dtype("float32").type(np.sqrt(head_dim))
    return _regroup_rope_dims(weights * depth_scale)

  return {
      "base.decoder.layers.self_attention.query.kernel": _scale_and_regroup_query,
      "base.decoder.layers.self_attention.key.kernel": _regroup_rope_dims,
  }


# Registry of parameter-name mapping generators, keyed by MaxText model name.
# Each value is a function (defined earlier in this file) that the conversion
# script calls to obtain {maxtext weight name: hf weight name}. Models sharing
# an architecture family share the same generator.
# {maxtext model name: {maxtext weight name: hf weight name}}
PARAM_MAPPING = {
    "gemma2-2b": GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING,
    "gemma2-9b": GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING,
    "gemma2-27b": GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING,
    "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
    "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
    "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
    "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING,
    "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
    "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
    "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
}

# Registry of hook-function generators, keyed by MaxText model name. Each value
# is a function that returns {maxtext weight name: bi-directional transform};
# model keys must stay in sync with PARAM_MAPPING above.
# NOTE(review): the gpt-oss entries use GPT_OSS_TO_HF_PARAM_HOOK_FN rather than
# the *_MAXTEXT_TO_HF_* naming used elsewhere — presumably just a naming
# inconsistency in the generator defined earlier in the file; verify there.
# {maxtext model name: {maxtext weight name: bi-directional transform}}
HOOK_FNS = {
    "gemma2-2b": GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "gemma2-9b": GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "gemma2-27b": GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
    "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
    "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
}

# Hook-function generators for the NNX -> vLLM conversion path.
# NOTE(review): unlike PARAM_MAPPING/HOOK_FNS, the qwen3 and llama3.1 keys here
# are model *families* rather than exact sizes, while deepseek keeps the sized
# key — presumably the vLLM lookup normalizes names differently; confirm with
# the caller that performs this lookup.
VLLM_HOOK_FNS = {
    "qwen3": QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN,
    "llama3.1": LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN,
    "deepseek3-671b": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
}
